"""Plot the singular-value spectrum of an input matrix loaded with np.load."""

from absl import app
from absl import flags
import matplotlib.pyplot as plt
import numpy as np
from sklearn.utils.extmath import randomized_svd

# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8*' and 3.8
# removed the old 'seaborn' name (using it raises OSError). Prefer the new name
# when it exists and fall back to the old one on older Matplotlib releases.
plt.style.use(
    'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

flags.DEFINE_string('input_file', 'input.npy', 'File with input matrix')
# When unset, main() shows the plot interactively instead of saving it.
flags.DEFINE_string('output_file', None, 'File with spectrum plot')
flags.DEFINE_string('title', 'Singular values of matrix', 'Title of plot')
# When unset, main() computes the full (exact) spectrum with np.linalg.svd;
# otherwise only the top `rank` singular values are estimated. A non-positive
# rank would yield an empty/meaningless plot, so reject it at parse time.
flags.DEFINE_integer('rank', None, 'Rank at which to truncate', lower_bound=1)
FLAGS = flags.FLAGS

def main(argv) -> None:
  """Loads a matrix, computes its singular values and plots them.

  The matrix is read from --input_file. If --rank is set, only the top
  `rank` singular values are estimated with a randomized SVD; otherwise the
  full spectrum is computed exactly. The plot is written to --output_file
  if given, otherwise shown interactively.

  Args:
    argv: Command-line arguments remaining after absl flag parsing; argv[0]
      is the program name. No extra positional arguments are accepted.

  Raises:
    app.UsageError: If extra positional arguments are passed, or if the
      input file does not hold a single 2-D array.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  mat = np.load(FLAGS.input_file)
  # np.load returns an NpzFile for .npz archives, and np.linalg.svd on an
  # N-D (N > 2) array returns a stack of spectra; neither plots as a single
  # curve, so require one plain matrix.
  if not isinstance(mat, np.ndarray) or mat.ndim != 2:
    raise app.UsageError(
        f'{FLAGS.input_file} must contain a single 2-D array.')

  if FLAGS.rank is not None:
    # Explicit seed keeps repeated runs reproducible and independent of
    # scikit-learn's (version-dependent) default random_state.
    _, s, _ = randomized_svd(mat, n_components=FLAGS.rank, random_state=0)
  else:
    s = np.linalg.svd(mat, compute_uv=False)

  plt.plot(range(len(s)), s)
  plt.ylabel('Singular value', fontsize=16)
  plt.xlabel('Singular value index', fontsize=16)
  plt.title(FLAGS.title, fontsize=20)
  if FLAGS.output_file is not None:
    plt.savefig(FLAGS.output_file)
  else:
    plt.show()


if __name__ == '__main__':
  # absl parses the flags defined above from sys.argv, then calls main with
  # the remaining (non-flag) arguments.
  app.run(main=main)
